{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load file: ../feature/DWTfeature_S1.mat...[ok]:206751\n",
      "load file: ../feature/DWTfeature_S2.mat...[ok]:206659\n",
      "(97210, 12, 4) (97210, 1)\n",
      "(105147, 12, 4) (97210, 1)\n"
     ]
    }
   ],
   "source": [
    "from scipy import io\n",
    "import numpy as np\n",
    "\n",
    "def get_feature_dict(filename):\n",
    "    \"\"\"将ninapro_feature的MAT文件加载为字典\n",
    "\n",
    "    Args:\n",
    "        path: mat文件路径\n",
    "\n",
    "    Returns:\n",
    "        数据集字典\n",
    "        [feat_set, featStim, featRep]\n",
    "    \"\"\"\n",
    "    # 读取MAT文件\n",
    "    print('load file: ' + filename + '...', end= '', flush=True)\n",
    "    dict_feature=io.loadmat(filename)\n",
    "    if (dict_feature != ()):\n",
    "        #print(ninapro_data.keys())\n",
    "        print('[ok]:%d'%(len(dict_feature['featStim'])), flush=True)\n",
    "    # 返回字典\n",
    "    return dict_feature\n",
    "\n",
    "def split_zeros(feature_dict,feature_name,channels):\n",
    "    \"\"\"将ninapro_feature数据集中【restimulate】为0的部分（受试者不做动作）从数据集中去除\n",
    "\n",
    "    Args:\n",
    "        feature_dict: 数据集字典\n",
    "        feature_name: 待处理的数据的keyvalue\n",
    "        channels: 待处理的数据的通道数\n",
    "\n",
    "    Returns:\n",
    "        [feature_split, labels] 去除0部分的数据，对于的label(numpy array)\n",
    "    \"\"\"\n",
    "    feature_split = None\n",
    "    index = []\n",
    "    for i in range(len(feature_dict['featStim'])):\n",
    "        if feature_dict['featStim'][i]!=0:\n",
    "            index.append(i)\n",
    "    # 重排元素\n",
    "    emg_temp = feature_dict[feature_name]\n",
    "    emg_temp = np.reshape(emg_temp,(-1,4,channels))\n",
    "    emg_temp = np.swapaxes(emg_temp,1,2)\n",
    "    # 去除0label\n",
    "    if(feature_split is None):\n",
    "        feature_split = emg_temp[index,:,:]\n",
    "        labels = feature_dict['featStim'][index,:]\n",
    "    else:\n",
    "        feature_split = np.vstack((feature_split,emg_temp[index,:,:])) \n",
    "        labels = np.vstack((labels,feature_dict['featStim'][index,:]))\n",
    "    return feature_split, labels\n",
    "\n",
    "# 对多组数据合并，预处理\n",
    "def merge_multisubject(b,e):\n",
    "    \"\"\"将多组数据从mat文件中提取出来，预处理后合并\n",
    "\n",
    "    Args:\n",
    "        b: 开始的受试者序号\n",
    "        e: 结束的受试者序号\n",
    "\n",
    "    Returns:\n",
    "        [emg,acc,gyro,mag,labels]肌电c12，加速度c36，角速度c36，磁强c36数据和标签。\n",
    "    \"\"\"\n",
    "    emg_feature = None\n",
    "    labels = None\n",
    "    # 遍历受试者序号\n",
    "    for i in range(b,e+1):\n",
    "        emg_dict = get_feature_dict(\"../feature/DWTfeature_S{0}.mat\".format(i))\n",
    "        # 寻找动作为0的元素并剔除\n",
    "        emg,labels = split_zeros(emg_dict,'feat_set',12)\n",
    "        #print('delete 0 label,',emg_temp[index,:,:].shape)\n",
    "    # s = [1.28889041e-05, 0.00000000e+00, 1.72402617e+01, 1.57331247e+01, 2.11883893e-03]\n",
    "    # 归一化\n",
    "    # for i in range(5):\n",
    "    #     #s[i] = np.sum(np.abs(emg_feature[:,:,i]))/emg_feature[:,:,i].size\n",
    "    #     #print(\"avg=\",s)\n",
    "    #     if(s[i]!=0):\n",
    "    #         emg_feature[:,:,i] /= s[i]\n",
    "    #         emg_feature[:,:,i] -= 0.5*s[i]\n",
    "    return emg, labels\n",
    "\n",
    "# 读取2组数据分布作训练集和验证集\n",
    "emg_feature,labels = merge_multisubject(1,1)\n",
    "emg_feature_test,labels_test = merge_multisubject(2,2)\n",
    "print(emg_feature.shape,labels.shape)\n",
    "print(emg_feature_test.shape,labels.shape)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(97210, 48) (97210,)\n",
      "(105147, 48) (105147,)\n",
      "training...\n",
      "0.8795802900936117\n",
      "0.12319895004137066\n"
     ]
    }
   ],
   "source": [
    "# 训练集预处理\n",
    "emg_feature = np.reshape(emg_feature,(-1,48))\n",
    "labels = labels.flatten() - 1\n",
    "print(emg_feature.shape,labels.shape)\n",
    "\n",
    "# 数据集预处理\n",
    "emg_feature_test = np.reshape(emg_feature_test,(-1,48))\n",
    "labels_test = labels_test.flatten() - 1\n",
    "print(emg_feature_test.shape,labels_test.shape)\n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# 随机森林训练\n",
    "print('training...')\n",
    "model = RandomForestClassifier(random_state=0,max_depth = 10)\n",
    "model.fit(emg_feature, labels)\n",
    "\n",
    "# train_acc\n",
    "score_a = model.score(emg_feature,labels)\n",
    "print(score_a)\n",
    "# 随机森林验证\n",
    "score_t = model.score(emg_feature_test,labels_test)\n",
    "print(score_t)\n",
    "\n",
    "-a----         2022/12/7     22:19           7402 ninapro_utils.py\n",
    "-a----        2022/12/13      2:40          19014 sEMG_proc_CNNs.ipynb\n",
    "-a----        2022/12/13      2:46         197845 sEMG_proc_dataExtraction.ipynb\n",
    "-a----        2022/12/13     15:04           5972 sEMG_proc_DWT.ipynb\n",
    "-a----        2022/12/13      1:57          54737 sEMG_proc_LFM.ipynb\n",
    "-a----        2022/12/12     17:50          45352 semg_train.ipynb"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13 (default, Oct 19 2022, 22:38:03) [MSC v.1916 64 bit (AMD64)]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "91eaa33755ee0e6c8927d7837736bb7bf44cb27baec87b08c7a4a0a98fc82110"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
